# Names exported by a star-import of this module; helpers such as
# ``check_response`` and ``xstr`` are intentionally left out of the list.
__all__ = ['AbstractRequestHandler', 'RequestHandler', 'TokenBasedAuth', 'format_endpoint']

import requests as rqst
import requests.auth as auth
import requests.compat as urlcompat
import logging
from copy import copy
from abc import abstractmethod, ABC
from {{packageName}}.const import *

# Module-level logger, named after this module so handlers/levels can be
# configured per module by the application.
LOGGER = logging.getLogger(__name__)


def check_response(response: rqst.Response):
    """Raise ``HTTPError`` when *response* does not report success.

    Args:
        response: The response object to inspect; its ``ok`` flag decides
            whether an error is raised.

    Raises:
        requests.exceptions.HTTPError: If ``response.ok`` is falsy. The
            message carries the status code and the raw response text, and
            the response itself is attached to the exception so callers can
            inspect status, headers or body from the ``except`` clause.
    """
    if response.ok:
        return
    raise rqst.exceptions.HTTPError(
        f'Server returned {response.status_code}; raw exception:\n {response.text}',
        response=response,
    )


def xstr(s):
    """Return ``str(s)``, mapping ``None`` to the empty string."""
    return '' if s is None else str(s)


def format_endpoint(endpoint: str, **params):
    """Fill ``{name}`` placeholders in *endpoint* and append an optional ``fields`` query.

    Args:
        endpoint: Endpoint template containing zero or more ``{name}``
            placeholders.
        **params: Placeholder values keyed by placeholder name. ``None`` is
            substituted as an empty string. A truthy ``fields`` value is
            additionally appended as a ``fields=<value>`` query parameter.

    Returns:
        The formatted endpoint, with every run of consecutive slashes
        collapsed to a single slash.
    """
    # Strings are immutable, so rebinding is enough; no copy is required.
    formatted_endpoint = endpoint
    for name, value in params.items():
        placeholder = '{' + name + '}'
        if placeholder in formatted_endpoint:
            formatted_endpoint = formatted_endpoint.replace(placeholder, xstr(value))

    fields = params.get('fields')
    if fields:
        # Continue an existing query string with '&', otherwise start one.
        separator = '&' if '?' in formatted_endpoint else '?'
        formatted_endpoint += f'{separator}fields={fields}'

    # Empty locators leave doubled slashes behind. A single replace pass turns
    # '///' into '//', so repeat until no doubled slash remains.
    while '//' in formatted_endpoint:
        formatted_endpoint = formatted_endpoint.replace('//', '/')
    return formatted_endpoint


class TokenBasedAuth(auth.AuthBase):
    """Auth hook that attaches a bearer token to each request it is called with."""

    def __init__(self, token):
        # The token is interpolated verbatim into the Authorization header.
        self.token = token

    def __call__(self, r):
        """Set the ``Authorization`` header on *r* and hand *r* back."""
        r.headers['Authorization'] = f'Bearer {self.token}'
        return r

    def __eq__(self, other):
        """Compare by token; an object lacking a ``token`` attribute counts as ``None``."""
        other_token = getattr(other, 'token', None)
        return self.token == other_token

    def __ne__(self, other):
        """Return the negation of ``self == other``."""
        if self == other:
            return False
        return True


class AbstractRequestHandler(ABC):
    """Interface for objects able to perform a request against an endpoint."""

    @abstractmethod
    def request(self, endpoint: str, http_method: str, body: str = None, **kwargs):
        """Perform a request; concrete subclasses must override this method.

        Args:
            endpoint: Endpoint to call.
            http_method: HTTP method name.
            body: Optional request body.
            **kwargs: Extra implementation-specific options.
        """
        return None


class RequestHandler(AbstractRequestHandler):
    """Send HTTP requests relative to a fixed base URL through ``requests``.

    Default ``Content-Type`` and ``Accept`` headers are derived from the
    *content_type* given at construction and are sent with every request.
    """

    def __init__(self, base_url: str, auth: auth.AuthBase, content_type: ContentTypeEnum = ContentTypeEnum.JSON):
        """Store the base URL, the auth hook and the default headers.

        Args:
            base_url: URL that every endpoint is joined onto.
            auth: Auth object handed to ``requests`` on each call.
            content_type: Value used for both the content-type and the
                accept default headers.
        """
        self._base_url = base_url
        self._auth = auth
        self._headers = {
            HeadersEnum.CONTENT_TYPE.value: content_type.value,
            HeadersEnum.ACCEPT.value: content_type.value,
        }

    def request(self, endpoint: str, http_method: str, body: str = None, extra_headers: dict = None, **kwargs) -> rqst.Response:
        """Send one request and return its response.

        Args:
            endpoint: Endpoint joined onto the base URL.
            http_method: HTTP method name passed to ``requests``.
            body: Optional payload sent as the request data.
            extra_headers: Optional headers for this request only; they
                override the defaults on key collision.
            **kwargs: Forwarded unchanged to ``requests.request``.

        Returns:
            The response object produced by ``requests.request``.

        Raises:
            requests.exceptions.HTTPError: Raised by ``check_response`` when
                the response does not report success.
        """
        prepped_endpoint = self._prepare_endpoint(endpoint)
        prepped_headers = self._prepare_headers(extra_headers)
        self._printout(prepped_endpoint, http_method, body=body, headers=prepped_headers)
        response = rqst.request(auth=self._auth, url=prepped_endpoint, method=http_method, headers=prepped_headers,
                                data=body, **kwargs)
        check_response(response)
        return response

    def _prepare_endpoint(self, endpoint: str):
        """Join *endpoint* onto the base URL using ``urljoin``."""
        return urlcompat.urljoin(self._base_url, endpoint)

    def _prepare_headers(self, extra_headers: dict = None):
        """Return the headers for a single request.

        A new dict is built on every call so that *extra_headers* never leak
        into the instance defaults and therefore into later requests.
        """
        output = dict(self._headers)
        if extra_headers:
            output.update(extra_headers)
        return output

    def _printout(self, endpoint, method, body=None, headers=None):
        """Log the outgoing request at INFO level."""
        # Lazy %-style arguments: formatting is skipped when INFO is disabled.
        LOGGER.info('[%s]%s', method.upper(), endpoint)
        if headers is not None:
            LOGGER.info('Request headers: %s', headers)
        if body is not None:
            LOGGER.info('Request data: %s', body)